{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Reinforcement Learning with Tensorflow Part 8: Asynchronus Advantage Actor-Critic (A3C)\n",
    "\n",
    "This iPython notebook includes an implementation of the [A3C algorithm](https://arxiv.org/pdf/1602.01783.pdf). In it we use A3C to solve a simple 3D Doom challenge using the [VizDoom engine](http://vizdoom.cs.put.edu.pl/). For more information on A3C, see the accompanying [Medium post](https://medium.com/p/c88f72a5e9f2/edit).\n",
    "\n",
    "This tutorial requires that VizDoom is installed. It can be easily obtained with:\n",
    "\n",
    "`pip install vizdoom`\n",
    "\n",
    "We also require `basic.wad` and `helper.py`, both of which are available from the [DeepRL-Agents github repo](https://github.com/awjuliani/DeepRL-Agents).\n",
    "\n",
    "While training is taking place, statistics on agent performance are available from Tensorboard. To launch it use:\n",
    "\n",
    "`tensorboard --logdir=worker_0:'./train_0',worker_1:'./train_1',worker_2:'./train_2',worker_3:'./train_3'`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import multiprocessing\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import tensorflow.contrib.slim as slim\n",
    "import scipy.signal\n",
    "%matplotlib inline\n",
    "from helper import *\n",
    "from vizdoom import *\n",
    "\n",
    "from random import choice\n",
    "from time import sleep\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copies one set of variables to another.\n",
    "# Used to set worker network parameters to those of global network.\n",
    "def update_target_graph(from_scope,to_scope):\n",
    "    from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)\n",
    "    to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)\n",
    "\n",
    "    op_holder = []\n",
    "    for from_var,to_var in zip(from_vars,to_vars):\n",
    "        op_holder.append(to_var.assign(from_var))\n",
    "    return op_holder\n",
    "\n",
    "# Processes Doom screen image to produce cropped and resized image. \n",
    "def process_frame(frame):\n",
    "    s = frame[10:-10,30:-30]\n",
    "    s = scipy.misc.imresize(s,[84,84])\n",
    "    s = np.reshape(s,[np.prod(s.shape)]) / 255.0\n",
    "    return s\n",
    "\n",
    "# Discounting function used to calculate discounted returns.\n",
    "def discount(x, gamma):\n",
    "    return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]\n",
    "\n",
    "#Used to initialize weights for policy and value output layers\n",
    "def normalized_columns_initializer(std=1.0):\n",
    "    def _initializer(shape, dtype=None, partition_info=None):\n",
    "        out = np.random.randn(*shape).astype(np.float32)\n",
    "        out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))\n",
    "        return tf.constant(out)\n",
    "    return _initializer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Actor-Critic Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AC_Network():\n",
    "    def __init__(self,s_size,a_size,scope,trainer):\n",
    "        with tf.variable_scope(scope):\n",
    "            #Input and visual encoding layers\n",
    "            self.inputs = tf.placeholder(shape=[None,s_size],dtype=tf.float32)\n",
    "            self.imageIn = tf.reshape(self.inputs,shape=[-1,84,84,1])\n",
    "            self.conv1 = slim.conv2d(activation_fn=tf.nn.elu,\n",
    "                inputs=self.imageIn,num_outputs=16,\n",
    "                kernel_size=[8,8],stride=[4,4],padding='VALID')\n",
    "            self.conv2 = slim.conv2d(activation_fn=tf.nn.elu,\n",
    "                inputs=self.conv1,num_outputs=32,\n",
    "                kernel_size=[4,4],stride=[2,2],padding='VALID')\n",
    "            hidden = slim.fully_connected(slim.flatten(self.conv2),256,activation_fn=tf.nn.elu)\n",
    "            \n",
    "            #Recurrent network for temporal dependencies\n",
    "            lstm_cell = tf.contrib.rnn.BasicLSTMCell(256,state_is_tuple=True)\n",
    "            c_init = np.zeros((1, lstm_cell.state_size.c), np.float32)\n",
    "            h_init = np.zeros((1, lstm_cell.state_size.h), np.float32)\n",
    "            self.state_init = [c_init, h_init]\n",
    "            c_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.c])\n",
    "            h_in = tf.placeholder(tf.float32, [1, lstm_cell.state_size.h])\n",
    "            self.state_in = (c_in, h_in)\n",
    "            rnn_in = tf.expand_dims(hidden, [0])\n",
    "            step_size = tf.shape(self.imageIn)[:1]\n",
    "            state_in = tf.contrib.rnn.LSTMStateTuple(c_in, h_in)\n",
    "            lstm_outputs, lstm_state = tf.nn.dynamic_rnn(\n",
    "                lstm_cell, rnn_in, initial_state=state_in, sequence_length=step_size,\n",
    "                time_major=False)\n",
    "            lstm_c, lstm_h = lstm_state\n",
    "            self.state_out = (lstm_c[:1, :], lstm_h[:1, :])\n",
    "            rnn_out = tf.reshape(lstm_outputs, [-1, 256])\n",
    "            \n",
    "            #Output layers for policy and value estimations\n",
    "            self.policy = slim.fully_connected(rnn_out,a_size,\n",
    "                activation_fn=tf.nn.softmax,\n",
    "                weights_initializer=normalized_columns_initializer(0.01),\n",
    "                biases_initializer=None)\n",
    "            self.value = slim.fully_connected(rnn_out,1,\n",
    "                activation_fn=None,\n",
    "                weights_initializer=normalized_columns_initializer(1.0),\n",
    "                biases_initializer=None)\n",
    "            \n",
    "            #Only the worker network need ops for loss functions and gradient updating.\n",
    "            if scope != 'global':\n",
    "                self.actions = tf.placeholder(shape=[None],dtype=tf.int32)\n",
    "                self.actions_onehot = tf.one_hot(self.actions,a_size,dtype=tf.float32)\n",
    "                self.target_v = tf.placeholder(shape=[None],dtype=tf.float32)\n",
    "                self.advantages = tf.placeholder(shape=[None],dtype=tf.float32)\n",
    "\n",
    "                self.responsible_outputs = tf.reduce_sum(self.policy * self.actions_onehot, [1])\n",
    "\n",
    "                #Loss functions\n",
    "                self.value_loss = 0.5 * tf.reduce_sum(tf.square(self.target_v - tf.reshape(self.value,[-1])))\n",
    "                self.entropy = - tf.reduce_sum(self.policy * tf.log(self.policy))\n",
    "                self.policy_loss = -tf.reduce_sum(tf.log(self.responsible_outputs)*self.advantages)\n",
    "                self.loss = 0.5 * self.value_loss + self.policy_loss - self.entropy * 0.01\n",
    "\n",
    "                #Get gradients from local network using local losses\n",
    "                local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)\n",
    "                self.gradients = tf.gradients(self.loss,local_vars)\n",
    "                self.var_norms = tf.global_norm(local_vars)\n",
    "                grads,self.grad_norms = tf.clip_by_global_norm(self.gradients,40.0)\n",
    "                \n",
    "                #Apply local gradients to global network\n",
    "                global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global')\n",
    "                self.apply_grads = trainer.apply_gradients(zip(grads,global_vars))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Worker Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Worker():\n",
    "    def __init__(self,game,name,s_size,a_size,trainer,model_path,global_episodes):\n",
    "        self.name = \"worker_\" + str(name)\n",
    "        self.number = name        \n",
    "        self.model_path = model_path\n",
    "        self.trainer = trainer\n",
    "        self.global_episodes = global_episodes\n",
    "        self.increment = self.global_episodes.assign_add(1)\n",
    "        self.episode_rewards = []\n",
    "        self.episode_lengths = []\n",
    "        self.episode_mean_values = []\n",
    "        self.summary_writer = tf.summary.FileWriter(\"train_\"+str(self.number))\n",
    "\n",
    "        #Create the local copy of the network and the tensorflow op to copy global paramters to local network\n",
    "        self.local_AC = AC_Network(s_size,a_size,self.name,trainer)\n",
    "        self.update_local_ops = update_target_graph('global',self.name)        \n",
    "        \n",
    "        #The Below code is related to setting up the Doom environment\n",
    "        game.set_doom_scenario_path(\"basic.wad\") #This corresponds to the simple task we will pose our agent\n",
    "        game.set_doom_map(\"map01\")\n",
    "        game.set_screen_resolution(ScreenResolution.RES_160X120)\n",
    "        game.set_screen_format(ScreenFormat.GRAY8)\n",
    "        game.set_render_hud(False)\n",
    "        game.set_render_crosshair(False)\n",
    "        game.set_render_weapon(True)\n",
    "        game.set_render_decals(False)\n",
    "        game.set_render_particles(False)\n",
    "        game.add_available_button(Button.MOVE_LEFT)\n",
    "        game.add_available_button(Button.MOVE_RIGHT)\n",
    "        game.add_available_button(Button.ATTACK)\n",
    "        game.add_available_game_variable(GameVariable.AMMO2)\n",
    "        game.add_available_game_variable(GameVariable.POSITION_X)\n",
    "        game.add_available_game_variable(GameVariable.POSITION_Y)\n",
    "        game.set_episode_timeout(300)\n",
    "        game.set_episode_start_time(10)\n",
    "        game.set_window_visible(False)\n",
    "        game.set_sound_enabled(False)\n",
    "        game.set_living_reward(-1)\n",
    "        game.set_mode(Mode.PLAYER)\n",
    "        game.init()\n",
    "        self.actions = self.actions = np.identity(a_size,dtype=bool).tolist()\n",
    "        #End Doom set-up\n",
    "        self.env = game\n",
    "        \n",
    "    def train(self,rollout,sess,gamma,bootstrap_value):\n",
    "        rollout = np.array(rollout)\n",
    "        observations = rollout[:,0]\n",
    "        actions = rollout[:,1]\n",
    "        rewards = rollout[:,2]\n",
    "        next_observations = rollout[:,3]\n",
    "        values = rollout[:,5]\n",
    "        \n",
    "        # Here we take the rewards and values from the rollout, and use them to \n",
    "        # generate the advantage and discounted returns. \n",
    "        # The advantage function uses \"Generalized Advantage Estimation\"\n",
    "        self.rewards_plus = np.asarray(rewards.tolist() + [bootstrap_value])\n",
    "        discounted_rewards = discount(self.rewards_plus,gamma)[:-1]\n",
    "        self.value_plus = np.asarray(values.tolist() + [bootstrap_value])\n",
    "        advantages = rewards + gamma * self.value_plus[1:] - self.value_plus[:-1]\n",
    "        advantages = discount(advantages,gamma)\n",
    "\n",
    "        # Update the global network using gradients from loss\n",
    "        # Generate network statistics to periodically save\n",
    "        feed_dict = {self.local_AC.target_v:discounted_rewards,\n",
    "            self.local_AC.inputs:np.vstack(observations),\n",
    "            self.local_AC.actions:actions,\n",
    "            self.local_AC.advantages:advantages,\n",
    "            self.local_AC.state_in[0]:self.batch_rnn_state[0],\n",
    "            self.local_AC.state_in[1]:self.batch_rnn_state[1]}\n",
    "        v_l,p_l,e_l,g_n,v_n, self.batch_rnn_state,_ = sess.run([self.local_AC.value_loss,\n",
    "            self.local_AC.policy_loss,\n",
    "            self.local_AC.entropy,\n",
    "            self.local_AC.grad_norms,\n",
    "            self.local_AC.var_norms,\n",
    "            self.local_AC.state_out,\n",
    "            self.local_AC.apply_grads],\n",
    "            feed_dict=feed_dict)\n",
    "        return v_l / len(rollout),p_l / len(rollout),e_l / len(rollout), g_n,v_n\n",
    "        \n",
    "    def work(self,max_episode_length,gamma,sess,coord,saver):\n",
    "        episode_count = sess.run(self.global_episodes)\n",
    "        total_steps = 0\n",
    "        print (\"Starting worker \" + str(self.number))\n",
    "        with sess.as_default(), sess.graph.as_default():                 \n",
    "            while not coord.should_stop():\n",
    "                sess.run(self.update_local_ops)\n",
    "                episode_buffer = []\n",
    "                episode_values = []\n",
    "                episode_frames = []\n",
    "                episode_reward = 0\n",
    "                episode_step_count = 0\n",
    "                d = False\n",
    "                \n",
    "                self.env.new_episode()\n",
    "                s = self.env.get_state().screen_buffer\n",
    "                episode_frames.append(s)\n",
    "                s = process_frame(s)\n",
    "                rnn_state = self.local_AC.state_init\n",
    "                self.batch_rnn_state = rnn_state\n",
    "                while self.env.is_episode_finished() == False:\n",
    "                    #Take an action using probabilities from policy network output.\n",
    "                    a_dist,v,rnn_state = sess.run([self.local_AC.policy,self.local_AC.value,self.local_AC.state_out], \n",
    "                        feed_dict={self.local_AC.inputs:[s],\n",
    "                        self.local_AC.state_in[0]:rnn_state[0],\n",
    "                        self.local_AC.state_in[1]:rnn_state[1]})\n",
    "                    a = np.random.choice(a_dist[0],p=a_dist[0])\n",
    "                    a = np.argmax(a_dist == a)\n",
    "\n",
    "                    r = self.env.make_action(self.actions[a]) / 100.0\n",
    "                    d = self.env.is_episode_finished()\n",
    "                    if d == False:\n",
    "                        s1 = self.env.get_state().screen_buffer\n",
    "                        episode_frames.append(s1)\n",
    "                        s1 = process_frame(s1)\n",
    "                    else:\n",
    "                        s1 = s\n",
    "                        \n",
    "                    episode_buffer.append([s,a,r,s1,d,v[0,0]])\n",
    "                    episode_values.append(v[0,0])\n",
    "\n",
    "                    episode_reward += r\n",
    "                    s = s1                    \n",
    "                    total_steps += 1\n",
    "                    episode_step_count += 1\n",
    "                    \n",
    "                    # If the episode hasn't ended, but the experience buffer is full, then we\n",
    "                    # make an update step using that experience rollout.\n",
    "                    if len(episode_buffer) == 30 and d != True and episode_step_count != max_episode_length - 1:\n",
    "                        # Since we don't know what the true final return is, we \"bootstrap\" from our current\n",
    "                        # value estimation.\n",
    "                        v1 = sess.run(self.local_AC.value, \n",
    "                            feed_dict={self.local_AC.inputs:[s],\n",
    "                            self.local_AC.state_in[0]:rnn_state[0],\n",
    "                            self.local_AC.state_in[1]:rnn_state[1]})[0,0]\n",
    "                        v_l,p_l,e_l,g_n,v_n = self.train(episode_buffer,sess,gamma,v1)\n",
    "                        episode_buffer = []\n",
    "                        sess.run(self.update_local_ops)\n",
    "                    if d == True:\n",
    "                        break\n",
    "                                            \n",
    "                self.episode_rewards.append(episode_reward)\n",
    "                self.episode_lengths.append(episode_step_count)\n",
    "                self.episode_mean_values.append(np.mean(episode_values))\n",
    "                \n",
    "                # Update the network using the episode buffer at the end of the episode.\n",
    "                if len(episode_buffer) != 0:\n",
    "                    v_l,p_l,e_l,g_n,v_n = self.train(episode_buffer,sess,gamma,0.0)\n",
    "                                \n",
    "                    \n",
    "                # Periodically save gifs of episodes, model parameters, and summary statistics.\n",
    "                if episode_count % 5 == 0 and episode_count != 0:\n",
    "                    if self.name == 'worker_0' and episode_count % 25 == 0:\n",
    "                        time_per_step = 0.05\n",
    "                        images = np.array(episode_frames)\n",
    "                        make_gif(images,'./frames/image'+str(episode_count)+'.gif',\n",
    "                            duration=len(images)*time_per_step,true_image=True,salience=False)\n",
    "                    if episode_count % 250 == 0 and self.name == 'worker_0':\n",
    "                        saver.save(sess,self.model_path+'/model-'+str(episode_count)+'.cptk')\n",
    "                        print (\"Saved Model\")\n",
    "\n",
    "                    mean_reward = np.mean(self.episode_rewards[-5:])\n",
    "                    mean_length = np.mean(self.episode_lengths[-5:])\n",
    "                    mean_value = np.mean(self.episode_mean_values[-5:])\n",
    "                    summary = tf.Summary()\n",
    "                    summary.value.add(tag='Perf/Reward', simple_value=float(mean_reward))\n",
    "                    summary.value.add(tag='Perf/Length', simple_value=float(mean_length))\n",
    "                    summary.value.add(tag='Perf/Value', simple_value=float(mean_value))\n",
    "                    summary.value.add(tag='Losses/Value Loss', simple_value=float(v_l))\n",
    "                    summary.value.add(tag='Losses/Policy Loss', simple_value=float(p_l))\n",
    "                    summary.value.add(tag='Losses/Entropy', simple_value=float(e_l))\n",
    "                    summary.value.add(tag='Losses/Grad Norm', simple_value=float(g_n))\n",
    "                    summary.value.add(tag='Losses/Var Norm', simple_value=float(v_n))\n",
    "                    self.summary_writer.add_summary(summary, episode_count)\n",
    "\n",
    "                    self.summary_writer.flush()\n",
    "                if self.name == 'worker_0':\n",
    "                    sess.run(self.increment)\n",
    "                episode_count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "max_episode_length = 300\n",
    "gamma = .99 # discount rate for advantage estimation and reward discounting\n",
    "s_size = 7056 # Observations are greyscale frames of 84 * 84 * 1\n",
    "a_size = 3 # Agent can move Left, Right, or Fire\n",
    "load_model = False\n",
    "model_path = './model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "if not os.path.exists(model_path):\n",
    "    os.makedirs(model_path)\n",
    "    \n",
    "#Create a directory to save episode playback gifs to\n",
    "if not os.path.exists('./frames'):\n",
    "    os.makedirs('./frames')\n",
    "\n",
    "with tf.device(\"/cpu:0\"): \n",
    "    global_episodes = tf.Variable(0,dtype=tf.int32,name='global_episodes',trainable=False)\n",
    "    trainer = tf.train.AdamOptimizer(learning_rate=1e-4)\n",
    "    master_network = AC_Network(s_size,a_size,'global',None) # Generate global network\n",
    "    num_workers = multiprocessing.cpu_count() # Set workers to number of available CPU threads\n",
    "    workers = []\n",
    "    # Create worker classes\n",
    "    for i in range(num_workers):\n",
    "        workers.append(Worker(DoomGame(),i,s_size,a_size,trainer,model_path,global_episodes))\n",
    "    saver = tf.train.Saver(max_to_keep=5)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    coord = tf.train.Coordinator()\n",
    "    if load_model == True:\n",
    "        print ('Loading Model...')\n",
    "        ckpt = tf.train.get_checkpoint_state(model_path)\n",
    "        saver.restore(sess,ckpt.model_checkpoint_path)\n",
    "    else:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        \n",
    "    # This is where the asynchronous magic happens.\n",
    "    # Start the \"work\" process for each worker in a separate threat.\n",
    "    worker_threads = []\n",
    "    for worker in workers:\n",
    "        worker_work = lambda: worker.work(max_episode_length,gamma,sess,coord,saver)\n",
    "        t = threading.Thread(target=(worker_work))\n",
    "        t.start()\n",
    "        sleep(0.5)\n",
    "        worker_threads.append(t)\n",
    "    coord.join(worker_threads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
